
feat: refactor mcore train/forward utilities #1654

Merged
yuki-97 merged 19 commits into main from ashors/mcore-train on Feb 13, 2026

Conversation

Contributor

@ashors1 ashors1 commented Dec 17, 2025

What does this PR do?

Refactors the Megatron-core training/forward utilities: removes forward_step_arbitrary_loss, adds pipeline-parallel broadcast helpers, and consolidates forward/backward orchestration into a unified megatron_forward_backward entry point with post-processors for loss, logprobs, and top-k logits.

Nightly test results:

============================================================
sft-llama3.1-8b-1n8g-megatron
============================================================
                                 Metric Checks
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┓
┃ Status ┃ Check                                ┃ Value              ┃ Message ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━┩
│ PASS   │ data["train/loss"]["1"] < 0.6        │ 0.5873891115188599 │         │
│ PASS   │ data["train/loss"]["250"] < 0.36     │ 0.3181086480617523 │         │
│ PASS   │ mean(data["timing/train/total_step_… │ 12.973950623508438 │         │
│        │ 2) < 20                              │                    │         │
└────────┴──────────────────────────────────────┴────────────────────┴─────────┘

============================================================
sft-llama3.1-8b-1n8g-megatron-lora
============================================================
                                 Metric Checks
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┓
┃ Status ┃ Check                                ┃ Value              ┃ Message ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━┩
│ PASS   │ data["train/loss"]["1"] < 1.0        │ 0.8588194251060486 │         │
│ PASS   │ data["train/loss"]["50"] < 0.8       │ 0.717399537563324  │         │
│ PASS   │ max(data["ray/node.0.gpu.0.mem_gb"]) │ 52.1015625         │         │
│        │ < 60                                 │                    │         │
│ PASS   │ mean(data["timing/train/total_step_… │ 22.606327738080704 │         │
│        │ 2) < 30                              │                    │         │
└────────┴──────────────────────────────────────┴────────────────────┴─────────┘

============================================================
sft-llama3.1-8b-1n8g-megatron-seqpack
============================================================
                                 Metric Checks
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┓
┃ Status ┃ Check                                ┃ Value              ┃ Message ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━┩
│ PASS   │ data["train/loss"]["1"] < 0.6        │ 0.5874463319778442 │         │
│ PASS   │ data["train/loss"]["250"] < 0.36     │ 0.3183392286300659 │         │
│ PASS   │ mean(data["timing/train/total_step_… │ 5.329389685129066  │         │
│        │ 2) < 6                               │                    │         │
└────────┴──────────────────────────────────────┴────────────────────┴─────────┘

============================================================
sft-qwen2.5-math7b-2n8g-megatron
============================================================
                                 Metric Checks
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┓
┃ Status ┃ Check                               ┃ Value               ┃ Message ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━┩
│ PASS   │ data["train/loss"]["80"] < 0.301    │ 0.29687559604644775 │         │
│ PASS   │ data["validation/val_loss"]["80"] < │ 0.30329111218452454 │         │
│        │ 0.304                               │                     │         │
└────────┴─────────────────────────────────────┴─────────────────────┴─────────┘
============================================================
dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick
============================================================
                                 Metric Checks
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┓
┃ Status ┃ Check                                ┃ Value              ┃ Message ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━┩
│ PASS   │ data["train/loss"]["1"] < 3.6        │ 3.5130887031555176 │         │
│ PASS   │ data["train/loss"]["20"] < 3.4       │ 3.2065513134002686 │         │
│ PASS   │ data["train/preference_loss"]["1"] > │ 0.6931471824645996 │         │
│        │ 0.69314                              │                    │         │
│ PASS   │ data["train/preference_loss"]["1"] < │ 0.6931471824645996 │         │
│        │ 0.69316                              │                    │         │
│ PASS   │ data["train/preference_loss"]["20"]  │ 0.5910405516624451 │         │
│        │ < 0.6                                │                    │         │
│ PASS   │ mean(data["timing/train/total_step_… │ 4.939984655380249  │         │
│        │ -10) < 6.7                           │                    │         │
└────────┴──────────────────────────────────────┴────────────────────┴─────────┘

Issues

Closes #1593.
Closes #1744.

Usage

  • A hedged usage sketch follows.
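
The snippet below is a rough, self-contained illustration of the new flow only: the toy_* names are invented stand-ins for forward_with_post_processing_fn and LossPostProcessor, because the real entry points in nemo_rl/models/megatron/train.py need an initialized Megatron model, data iterator, and distributed setup, and their exact signatures may differ.

from typing import Callable, Tuple
import torch

def toy_loss_post_processor(logits: torch.Tensor) -> Tuple[torch.Tensor, dict]:
    # Stand-in for LossPostProcessor: reduce the model output to a scalar loss plus metrics.
    loss = logits.float().mean()
    return loss, {"loss": loss.item()}

def toy_forward_with_post_processing(
    logits: torch.Tensor,
    post_processing_fn: Callable[[torch.Tensor], Tuple[torch.Tensor, dict]],
) -> Tuple[torch.Tensor, dict]:
    # Stand-in for forward_with_post_processing_fn: run the forward pass (here the
    # logits are passed in directly) and hand the output to the chosen post-processor.
    return post_processing_fn(logits)

logits = torch.randn(2, 8, 32)  # toy [batch, seq, vocab] model output
loss, metrics = toy_forward_with_post_processing(logits, toy_loss_post_processor)
print(metrics)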

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

Release Notes

  • New Features

    • Added post-processing utilities for loss computation, log probability generation, and top-k logits extraction.
    • Introduced pipeline-parallel broadcasting utilities for distributed training.
    • Enhanced training infrastructure with improved forward/backward pass orchestration.
  • Refactor

    • Streamlined training workflow with consolidated post-processing architecture.
    • Refactored data handling to support improved type safety and distributed operations.
  • Tests

    • Added comprehensive unit test coverage for training utilities and post-processing pipelines.

@ashors1 ashors1 mentioned this pull request Dec 17, 2025
@ashors1 ashors1 force-pushed the ashors/mcore-train branch from f510015 to 2820fd4 on January 13, 2026 21:10
@ashors1 ashors1 requested review from cuichenx and terrykong January 15, 2026 20:15
Base automatically changed from ashors/mcore-data to main January 31, 2026 02:02
Contributor

@terrykong terrykong left a comment

thanks for all the refactoring work @ashors1

  1. I think we should merge this after @asolergi-nv looks into the padding issue in the packing path, since it would be good to confirm that this PR doesn't introduce a regression there

cc @ananthsub

Contributor

@yuki-97 yuki-97 left a comment

thank you so much for the refactor efforts! just one last step to complete. 🎉

@ananthsub ananthsub marked this pull request as ready for review February 6, 2026 09:13
@ananthsub ananthsub requested review from a team as code owners February 6, 2026 09:13
@ananthsub ananthsub added the CI:L0 Run doctests and unit tests label Feb 6, 2026
Contributor

coderabbitai bot commented Feb 6, 2026

📝 Walkthrough

This PR refactors Megatron-based training infrastructure by removing forward_step_arbitrary_loss, introducing new pipeline-parallel broadcasting utilities and training post-processors, and consolidating forward/backward orchestration into a unified megatron_forward_backward entry point with specialized post-processors for loss, logprobs, and top-k logits computation.

Changes

  • Removed Forward Step Function (nemo_rl/models/megatron/common.py):
    Removed the forward_step_arbitrary_loss function and related imports (GlobalState, GPTModel, parallel_state utilities), simplifying module dependencies while retaining utility functions such as _round_up_to_multiple, broadcast_tensor, and get_moe_metrics.
  • Data Processing Updates (nemo_rl/models/megatron/data.py):
    Updated the straggler_timer parameter of process_microbatch to Optional[StragglerDetector] with nullcontext handling; changed the process_global_batch return type from tuple to dict[str, Any] with corresponding docstring adjustments.
  • Pipeline Parallel Utilities (nemo_rl/models/megatron/pipeline_parallel.py):
    New module providing pipeline-parallel communication helpers: broadcast_obj_from_pp_rank for object broadcasting, broadcast_loss_metrics_from_last_stage for loss-metric distribution, and broadcast_tensors_from_last_stage for tensor propagation across stages (a sketch of the idea follows this list).
  • Training Infrastructure (nemo_rl/models/megatron/train.py):
    New module with core training utilities: model_forward for forward passes with context-parallel sharding and temperature scaling; apply_temperature_scaling for applying generation temperature; forward_with_post_processing_fn for forward passes with post-processing; megatron_forward_backward for unified forward/backward orchestration; and post-processor classes (LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor) for specialized output handling.
  • Policy Worker Refactoring (nemo_rl/models/policy/workers/megatron_policy_worker.py):
    Replaced manual PP broadcasting logic with the new helpers; consolidated forward/backward orchestration to use megatron_forward_backward with post-processors; updated imports and removed the legacy broadcast_object_across_pp_ranks; simplified the loss/logits broadcasting flow.
  • Test Updates (tests/unit/algorithms/test_sequence_packing_gradients.py):
    Updated the test to use the new LossPostProcessor and forward_with_post_processing_fn API instead of the removed forward_step_arbitrary_loss; added straggler-timer mocking scaffolding.
  • New Training Tests (tests/unit/models/megatron/test_train.py):
    Comprehensive test coverage for the new training utilities, including model_forward behavior, temperature scaling, forward_with_post_processing_fn integration, megatron_forward_backward orchestration, and detailed post-processor logic with packed/unpacked sequences and context-parallel normalization scenarios.
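
The snippet below is a minimal, hedged sketch of the broadcast_obj_from_pp_rank idea referenced above, assuming torch.distributed is already initialized and pp_group is the pipeline-parallel process group. The function name and arguments are illustrative rather than the actual helper in pipeline_parallel.py, and it includes the "exactly one owning rank" validation suggested in the review comments further down.

from typing import Any, Optional
import torch.distributed as dist

def broadcast_obj_from_owning_rank(obj: Optional[Any], pp_group: dist.ProcessGroup) -> Any:
    # Every pipeline rank reports its global rank and whether it holds the object.
    flags = [None] * dist.get_world_size(group=pp_group)
    dist.all_gather_object(flags, (dist.get_rank(), obj is not None), group=pp_group)
    # Enforce the contract: exactly one rank may own the object.
    owners = [global_rank for global_rank, has_obj in flags if has_obj]
    if len(owners) != 1:
        raise ValueError(f"Expected exactly one owning pipeline rank, found {owners}")
    # Broadcast the object itself from the owning rank to every rank in the group.
    buf = [obj]
    dist.broadcast_object_list(buf, src=owners[0], group=pp_group)
    return buf[0]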

Sequence Diagram(s)

sequenceDiagram
    participant Client as Training Loop
    participant FWD as megatron_forward_backward
    participant MF as model_forward
    participant PP as Post-Processor<br/>(Loss/Logprobs/TopK)
    participant BC as Pipeline Parallel<br/>Broadcast
    participant Stages as PP Stages

    Client->>FWD: Call with data_iterator,<br/>post_processing_fn
    FWD->>FWD: Create forward_step<br/>partial
    FWD->>FWD: Call Megatron<br/>forward_backward
    Note over FWD: Executes across<br/>pipeline stages
    FWD->>MF: model_forward on stage
    MF-->>FWD: logits (on last stage)
    FWD->>PP: Apply post-processing<br/>(e.g., compute loss)
    PP-->>FWD: processed output
    FWD->>BC: broadcast_loss_metrics<br/>or broadcast_tensors
    BC->>Stages: Gather from last stage
    Stages-->>BC: Distribute to all stages
    BC-->>FWD: Broadcasted result
    FWD-->>Client: Final output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

CI:L1

Suggested reviewers

  • terrykong
  • yuki-97
  • adil-a
🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 76.60%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Test Results For Major Changes ⚠️ Warning: The PR contains major breaking API changes and significant refactoring, but the description contains only template placeholders and no testing information. Resolution: update the PR description with a test coverage summary, confirmation of passing unit tests, justification for breaking changes, numerics/convergence verification, and performance analysis.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title "feat: refactor mcore train/forward utilities" directly matches the main changes: refactoring of the Megatron training and forward utilities across multiple files and introduction of a new train.py module with model_forward, post-processors, and megatron_forward_backward.
  • Description check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/models/megatron/common.py (1)

60-109: ⚠️ Potential issue | 🔴 Critical

Fix device metadata handling in broadcast_tensor to avoid TypeError at line 108 and wrong-device tensor allocation.

Line 63 gets torch.cuda.current_device() (returns int), then line 108 calls torch.device(device) which will raise TypeError since torch.device() does not accept int. Additionally, device metadata is not broadcast (line 75 only includes shape and dtype), so non-source ranks always assume the default CUDA device. If the source tensor is on a non-default device, non-source ranks allocate and validate against the wrong device.

Broadcast the device in metadata and use the received device directly in comparisons:

🔧 Suggested fix
-    # Assume operations happen on the default CUDA device for the rank
-    # TODO: Consider making device explicit if needed, e.g., derive from tensor on src
-    device = torch.cuda.current_device()
+    # Device will be broadcast as part of metadata
+    device = None
@@
-        metadata = [tensor.shape, tensor.dtype]
+        metadata = [tensor.shape, tensor.dtype, tensor.device]
@@
-    received_shape, received_dtype = object_list[0]
+    received_shape, received_dtype, received_device = object_list[0]
@@
-            tensor = torch.empty(received_shape, dtype=received_dtype, device=device)
+            tensor = torch.empty(received_shape, dtype=received_dtype, device=received_device)
@@
-            if tensor.device != torch.device(device):
+            if tensor.device != received_device:
                 raise ValueError(
                     f"Rank {rank}: Provided tensor is on device {tensor.device}, "
-                    f"but expected broadcast device is {device}."
+                    f"but expected broadcast device is {received_device}."
                 )
🤖 Fix all issues with AI agents
In `@nemo_rl/models/megatron/pipeline_parallel.py`:
- Around line 1-13: Update the NVIDIA copyright header year from 2025 to 2026 at
the top of the file (the existing license block in the current module) so the
header reads 2026, ensuring the rest of the Apache License text remains
unchanged; modify the banner in nemo_rl/models/megatron/pipeline_parallel.py
accordingly.
- Around line 53-67: The code currently selects the first True in obj_flags
without ensuring uniqueness; update the logic around obj_flags/pp_size/pp_group
to validate exactly one rank owns the object by counting True entries in
obj_flags (or collecting indices) and raise an error if count != 1, then set
src_rank to the single owning rank; preserve existing behavior of raising when
none exist but also raise when multiple ranks have the object to enforce the
function's contract (use the variables obj_flags, src_rank, and pp_group to
locate and implement this check).
- Around line 135-147: The last pipeline stage currently skips broadcast calls
when a tensor value is None, causing a deadlock because other stages still call
broadcast_tensor; update the is_pipeline_last_stage branch in the loop over
tensors so that for every key you call broadcast_tensor even if tensors[name] is
None—e.g., replace None with a sentinel empty tensor/object appropriate for your
dtype/device before calling broadcast_tensor(tensor, current_rank, pp_group)
(use the same device/dtype logic as in broadcast_tensor); ensure
broadcasted_tensors[name] gets the return value so the collectives run on the
last stage as well (refer to is_pipeline_last_stage, broadcast_tensor, tensors,
broadcasted_tensors, current_rank, last_rank, pp_group).
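
As a hedged illustration of the point above (every rank must enter the collective even when the last stage's value is None), here is a minimal sketch; broadcast_optional_tensor and its arguments are invented for this example and do not mirror the actual pipeline_parallel.py helpers. It assumes src is a global rank and that the source tensor already lives on the current CUDA device.

from typing import Optional
import torch
import torch.distributed as dist

def broadcast_optional_tensor(
    tensor: Optional[torch.Tensor], src: int, pp_group: dist.ProcessGroup
) -> Optional[torch.Tensor]:
    # All ranks first exchange metadata, so the source cannot silently skip the
    # collective (and deadlock the others) when its value happens to be None.
    meta = [None]
    if dist.get_rank() == src:
        meta[0] = None if tensor is None else (tensor.shape, tensor.dtype)
    dist.broadcast_object_list(meta, src=src, group=pp_group)
    if meta[0] is None:
        return None  # every rank agrees there is nothing to send and returns together
    shape, dtype = meta[0]
    if dist.get_rank() != src:
        device = torch.device("cuda", torch.cuda.current_device())
        tensor = torch.empty(shape, dtype=dtype, device=device)
    dist.broadcast(tensor, src=src, group=pp_group)
    return tensor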

In `@nemo_rl/models/megatron/train.py`:
- Around line 1-13: Update the NVIDIA copyright header in
nemo_rl/models/megatron/train.py by changing the year from 2025 to 2026 in the
file header; ensure the top comment block (the license header in train.py)
reflects "Copyright (c) 2026, NVIDIA CORPORATION" and keeps the rest of the
Apache 2.0 header unchanged.
- Around line 50-100: The call to apply_temperature_scaling in model_forward
unconditionally modifies logits (affecting training loss); change it so
temperature scaling runs only for inference/post‑processing paths: in
model_forward (and the similar block at lines ~165-175), detect whether
generation/inference postprocessing is active (e.g., check cfg["generation"] and
the postprocessor type or a flag like cfg.postprocessor == "inference" /
cfg.get("mode") == "inference") and only call
apply_temperature_scaling(output_tensor, cfg) when that condition is true;
otherwise leave logits unchanged for training (a hedged sketch of this gating follows the train.py items below).
- Around line 428-431: The commented-out pipeline-parallel guard around
is_pipeline_last_stage(ignore_virtual=True) and the return of
output_tensor.new_zeros(()) should be either removed or documented: either
delete those three commented lines if they are no longer needed, or replace them
with an explanatory comment that names is_pipeline_last_stage and
output_tensor.new_zeros and states why the PP guard is intentionally disabled
(e.g., because all PP stages now produce logits, testing reasons, or a temporary
debugging bypass) so future readers know the rationale.
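
Following up on the temperature-scaling item above, here is one hedged way the gating could look; the cfg keys used here ("mode", "generation", "temperature") are assumptions for this sketch, not necessarily the schema used in train.py.

import torch

def maybe_apply_temperature(logits: torch.Tensor, cfg: dict) -> torch.Tensor:
    # Only rescale logits when post-processing for generation; the training loss
    # should see the raw logits.
    if cfg.get("mode") != "inference":
        return logits
    temperature = cfg.get("generation", {}).get("temperature", 1.0)
    return logits if temperature == 1.0 else logits / temperature

logits = torch.randn(2, 4, 16)
train_logits = maybe_apply_temperature(logits, {"mode": "train"})  # unchanged
gen_logits = maybe_apply_temperature(
    logits, {"mode": "inference", "generation": {"temperature": 0.7}}
)  # rescaled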

In `@tests/unit/models/megatron/test_train.py`:
- Around line 343-351: The test assigns the return of megatron_forward_backward
to an unused variable result, triggering Ruff F841; fix it by either removing
the assignment and calling megatron_forward_backward(...) directly or by
assigning to a deliberately ignored name (e.g., _ ) so the return value is not
flagged as unused—update the call site where result is set (the
megatron_forward_backward invocation) accordingly.
🧹 Nitpick comments (1)
nemo_rl/models/megatron/data.py (1)

72-80: Align type hints with optional straggler_timer and actual return type.

process_microbatch now accepts Optional[StragglerDetector] and returns ProcessedInputs, but the upstream signatures still require a non-Optional timer, and the return annotation still advertises a tuple. This makes type checking misleading.

🔧 Suggested fix
 def make_processed_microbatch_iterator(
     raw_iterator: Iterator[BatchedDataDict[Any]],
     cfg: dict[str, Any],
     seq_length_key: Optional[str],
     pad_individual_seqs_to_multiple_of: int,
     pad_packed_seq_to_multiple_of: int,
-    straggler_timer: StragglerDetector,
+    straggler_timer: Optional[StragglerDetector],
     pad_full_seq_to: Optional[int],
 ) -> Iterator[ProcessedMicrobatch]:
@@
 def get_microbatch_iterator(
     data: BatchedDataDict[Any],
     cfg: dict[str, Any],
     mbs: int,
-    straggler_timer: StragglerDetector,
+    straggler_timer: Optional[StragglerDetector],
     seq_length_key: Optional[str] = None,
 ) -> Tuple[Iterator[ProcessedMicrobatch], int, int, int, int]:
@@
 def process_microbatch(
@@
-    straggler_timer: Optional[StragglerDetector] = None,
-) -> tuple[
-    torch.Tensor,
-    torch.Tensor,
-    Optional[torch.Tensor],
-    Optional[torch.Tensor],
-    Optional[PackedSeqParams],
-    Optional[torch.Tensor],
-]:
+    straggler_timer: Optional[StragglerDetector] = None,
+) -> ProcessedInputs:

Also applies to: 126-132, 208-216

@ananthsub ananthsub added CI:L0 Run doctests and unit tests and removed CI:L0 Run doctests and unit tests labels Feb 6, 2026
@ananthsub ananthsub added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L0 Run doctests and unit tests labels Feb 10, 2026
@ananthsub ananthsub removed the CI:L1 Run doctests, unit tests, and functional tests label Feb 10, 2026
ashors1 and others added 13 commits February 12, 2026 18:14
Signed-off-by: ashors1 <ashors@nvidia.com> (5 commits)
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com> (8 commits)
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 13, 2026
Contributor

yuki-97 commented Feb 13, 2026

Cancelled CI for now; waiting for #1902 to be merged first.

@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 13, 2026
@yuki-97 yuki-97 enabled auto-merge (squash) February 13, 2026 11:09
@yuki-97 yuki-97 merged commit 58f7c4c into main Feb 13, 2026
59 of 63 checks passed
@yuki-97 yuki-97 deleted the ashors/mcore-train branch February 13, 2026 19:56

Labels

CI:L1 Run doctests, unit tests, and functional tests

Development

Successfully merging this pull request may close these issues.

  • [meta issue] Refactoring Efforts
  • MegatronPolicyWorker refactor

5 participants